import random
import torch
import numpy as np
from torchvision.transforms import functional as F
from torchvision import transforms
import math
from PIL import Image
import cv2

def _flip_coco_person_keypoints(kps, width):
    flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    flipped_data = kps[:, flip_inds]
    flipped_data[..., 0] = width - flipped_data[..., 0]
    # Maintain COCO convention that if visibility == 0, then x, y = 0
    inds = flipped_data[..., 2] == 0
    flipped_data[inds] = 0
    return flipped_data


class Compose(object):
    """Chain several ``(image, target)`` transforms and apply them in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        # Each transform consumes and returns the (image, target) pair.
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target


class RandomHorizontalFlip(object):
    """With probability ``prob``, horizontally flip the image tensor and
    mirror the boxes, masks, and keypoints in ``target`` to match."""

    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        # Guard clause: leave the pair untouched when the coin flip fails.
        if random.random() >= self.prob:
            return image, target

        width = image.shape[-1]
        image = image.flip(-1)

        # Mirror boxes: new x1/x2 come from the reflected old x2/x1.
        boxes = target["boxes"]
        boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
        target["boxes"] = boxes

        if "masks" in target:
            target["masks"] = target["masks"].flip(-1)
        if "keypoints" in target:
            target["keypoints"] = _flip_coco_person_keypoints(
                target["keypoints"], width)
        return image, target

# class RandomCrop(object):
#
#     def __init__(self,prob,small_threshold):
#         self.prob = prob
#         self.small_threshold = small_threshold
#
#     def __call__(self, image, target):
#         if random.random() < self.prob:
#             imh = image.size(1)
#             imw = image.size(2)
#             short_size = min(imh, imw)
#             while True:
#                 # 选择任意一个crop pitch
#                 for _ in range(10):
#                     w = random.randrange(int(0.3 * short_size), short_size)  # 从0.3倍的最小边界开始
#                     h = int(w * 720 / 1280)
#
#                     x = random.randint(0, imw - w)
#                     y = random.randint(0, imh - h)  # 随机选择正方形区域
#                     roi = torch.Tensor([x, y, x + w, y + h])
#                     target["boxes"] = target["boxes"].numpy()  # 同理，tensor2numpy
#                     boxes = target["boxes"].copy()
#                     center = (boxes[:, :2] + boxes[:, 2:]) / 2  # (N,2)维, 中心点
#                     roi2 = roi.expand(len(center), 4)  # 把(1,4)expand到(N,4)维
#                     # 1. 选择包含box的roi
#                     mask = (center > roi2[:, :2]) & (center < roi2[:, 2:])  # crop pitch里面包含那个中心点(N,2)
#                     mask = mask[:, 0] & mask[:, 1]  # (N, 1)
#                     print("mask" + mask)
#                     # any对每个元素进行或运算，对每个元素进行与运算
#                     if not mask.any():  # 如果全为零，舍弃这个crop patch
#                         # im, boxes, labels = self.random_getim()
#                         # imh, imw, _ = im.shape
#                         # short_size = min(imh, imw)
#                         continue
#
#                     selected_boxes = boxes.index_select(0, mask.nonzero().squeeze(1))  # mask变为(N,)
#                     img = image[y:y + h, x:x + w]  # 裁剪区域
#                     selected_boxes[:, 0].add_(-x).clamp_(min=0, max=w)  # clamp 夹并在x,y之间
#                     selected_boxes[:, 1].add_(-y).clamp_(min=0, max=h)
#                     selected_boxes[:, 2].add_(-x).clamp_(min=0, max=w)
#                     selected_boxes[:, 3].add_(-y).clamp_(min=0, max=h)
#
#                     # expand_as(x) 表示扩展成x的尺寸
#                     boxes_uniform = selected_boxes / torch.Tensor([w, h, w, h]).expand_as(selected_boxes)
#                     boxeswh = boxes_uniform[:, 2:] - boxes_uniform[:, :2]
#
#                     # 2. 选择去掉box太小的
#                     mask = (boxeswh[:, 0] > self.small_threshold) & (boxeswh[:, 1] > self.small_threshold)
#                     print(mask)
#                     if not mask.any():  # 若全部为零，则舍弃
#                         im, boxes, labels = self.random_getim()
#                         imh, imw, _ = im.shape
#                         short_size = min(imh, imw)
#                         continue
#
#                     selected_boxes_selected = selected_boxes[mask.nonzero().squeeze(1)]  # mask变为(N,)
#                     selected_labels = labels.index_select(0, mask.nonzero().squeeze(1))
#                     return img, selected_boxes_selected, selected_labels
#         return image, target

# class RandomCrop(object):
#     def __init__(self,prob=0.1):
#         self.prob = prob
#
#     def __call__(self, image, target):
#         if random.random() < self.prob:
#             height = image.size(1)
#             width = image.size(2)
#             while True:
#                 # 迭代50次，保证能找到一个有目标的crop pitch
#                 for _ in range(50):
#                     current_image = image
#                     # 宽随机采样，高度跟随变化
#                     w = random.uniform(0.3 * width, width)
#                     h = w * 720 / 1280
#                     # # 宽高比例不当
#                     # if h / w < 0.5 or h / w > 2:
#                     #     continue
#                     left = random.uniform(.0, width - w)
#                     top = random.uniform(.0, height - h)
#                     # 框坐标x1,y1,x2,y2
#                     rect = np.array([int(left), int(top), int(left + w), int(top + h)])
#                     # # 求iou
#                     # overlap = iou(boxes, rect)
#                     # if overlap.min() < min_iou and max_iou < overlap.max():
#                     #     continue
#                     # 裁剪图像
#                     current_image = current_image[:, rect[1]: rect[3], rect[0]: rect[2]]
#                     # 中心点坐标
#                     print(target["boxes"])
#                     target["boxes"] = target["boxes"].numpy()  # 同理，tensor2numpy
#                     print(target["boxes"])
#                     centers = (target["boxes"][:, :2] + target["boxes"][:, 2:]) / 2.0
#                     # print(centers)
#                     m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
#                     m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
#                     # 当m1和m2均为正时才保留
#                     mask = m1 * m2
#                     if not mask.any():
#                         continue
#                     current_boxes = target["boxes"][mask, :].copy()
#                     current_labels = target["labels"][mask]
#                     # 根据图像变换调整box
#                     current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2])
#                     current_boxes[:, :2] -= rect[:2]
#                     current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:])
#                     current_boxes[:, 2:] -= rect[:2]
#                     target["boxes"] = current_boxes
#                     target["labels"] = current_labels
#                     image = current_image
#         return image, target

# class Expand(object):
#     def __init__(self, mean, prob):
#         self.prob = prob
#         self.mean = mean
#
#     def __call__(self, image, target):
#         if random.random() < self.prob:
#
#             # 获取图像的各个维度
#             depth = image.size(0)
#             height = image.size(1)
#             width = image.size(2)
#             # print(type(target))
#             # 随机缩放尺度
#             ratio = random.uniform(1, 4)
#             left = random.uniform(0, width * ratio - width)
#             top = random.uniform(0, height * ratio - height)
#             # 确定缩放后的图像的维度
#             # expand_image = np.zeros((int(height * ratio), int(width * ratio), depth))
#             expand_image = torch.zeros((depth, int(height * ratio), int(width * ratio)), dtype=torch.float32)
#             expand_image[:, :, :] = self.mean
#             expand_image[:, int(top): int(top + height), int(left): int(left + width)] = image
#             # 返回缩放后的图像
#             image = expand_image
#             # 将边界框以同等方式缩放
#             target["boxes"] = target["boxes"].numpy()  # 同理，tensor2numpy
#             boxes = target["boxes"].copy()
#             print(boxes)
#             boxes[:, :2] += (int(left), int(top))
#             boxes[:, 2:] += (int(left), int(top))
#             target["boxes"] = boxes
#             target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
#             print(type(target["boxes"]))
#             # 返回
#         return image, target


# Random zoom-in (scale the image up, then crop back to the original size)
class RandomCrop(object):
    """Randomly zoom into the image: enlarge it by a random ratio, then crop
    a window of the original size anchored at the bottom of the enlarged
    image (random horizontal offset).

    Boxes are rescaled/shifted into the crop.  Boxes that overflow the crop
    horizontally are clamped to the image edge; if clamping leaves a box
    narrower than 30 px it is dropped together with its label, and ``area``
    is recomputed for the survivors.

    Args:
        prob: probability of applying the transform.
        threshold: upper bound of the zoom ratio, sampled in [1.0, threshold].
    """

    def __init__(self, prob, threshold):
        self.prob = prob
        self.threshold = threshold

    def __call__(self, image, target):
        if random.random() < self.prob:
            # image is a CHW float tensor in [0, 1] (tensors use .size(),
            # numpy uses .shape).
            height = image.size(1)
            width = image.size(2)  # original size; the crop keeps this size
            ratio = random.uniform(1.0, self.threshold)
            h = int(height * ratio)
            w = int(width * ratio)  # enlarged size (>= original)

            # Random horizontal offset of the crop window inside the
            # enlarged image; vertically the crop is anchored at the bottom.
            left = int(random.uniform(0.0, float(w - width)))

            # tensor (CHW, [0,1]) -> uint8 HWC for cv2, resize, crop the
            # bottom strip of original size, then convert back to tensor.
            # NOTE(review): INTER_AREA is designed for shrinking; for an
            # upscale INTER_LINEAR is the usual choice — kept as-is here.
            img = (image.numpy() * 255).astype("uint8")
            img = np.transpose(img, (1, 2, 0))
            img = cv2.resize(img, (int(w), int(h)), interpolation=cv2.INTER_AREA)
            img = img[h - height:h, left:(left + width)]
            # BUGFIX: removed cv2.cvtColor(img, cv2.COLOR_RGB2BGR) — there
            # was no matching forward conversion, so it permanently swapped
            # the channel order of the output tensor.
            img = np.transpose(img, (2, 0, 1)).astype("float32")
            img /= 255
            image = torch.from_numpy(img)

            boxes = target["boxes"].numpy().copy()
            labels = target["labels"].numpy().copy()

            # Rescale boxes by ratio; the crop keeps the bottom of the
            # enlarged image, so y is corrected relative to the bottom edge.
            box_w = (boxes[:, 2] - boxes[:, 0]) * ratio
            box_h = (boxes[:, 3] - boxes[:, 1]) * ratio
            new_left = boxes[:, 0] * ratio - left
            new_top = height - (height - boxes[:, 1]) * ratio
            new_right = new_left + box_w
            new_bottom = new_top + box_h

            # Clamp horizontal overflow and drop boxes that become too
            # narrow.  BUGFIX: the previous backward-delete loop could index
            # past the array after a deletion (IndexError) and hard-coded a
            # 1280 px image width; it now uses the actual ``width``.
            kept_boxes = []
            kept_labels = []
            for i in range(len(boxes)):
                l = new_left[i]
                r = new_right[i]
                clamped = False
                if l <= 0:
                    l = 1.0
                    clamped = True
                if r >= width:
                    r = width - 1.0
                    clamped = True
                # Discard boxes narrower than 30 px after clamping.
                if clamped and (r - l) < 30:
                    continue
                kept_boxes.append([l, new_top[i], r, new_bottom[i]])
                kept_labels.append(labels[i])

            # Put everything back as tensors (boxes must be float).
            if kept_boxes:
                new_boxes = torch.as_tensor(kept_boxes, dtype=torch.float32)
            else:
                # All boxes dropped: keep a well-formed empty (0, 4) tensor
                # so downstream area/label computation stays valid.
                new_boxes = torch.zeros((0, 4), dtype=torch.float32)
            target["boxes"] = new_boxes

            # Recompute areas for the surviving boxes.
            target["area"] = ((new_boxes[:, 3] - new_boxes[:, 1])
                              * (new_boxes[:, 2] - new_boxes[:, 0]))

            target["labels"] = torch.as_tensor(kept_labels, dtype=torch.int64)

        return image, target


# Random zoom-out (shrink the image and pad the border with the mean color)
class RandomPadding(object):
    """Randomly shrink the image by a ratio in [threshold, 1.0] and paste it
    at a random position onto a mean-filled canvas of the original size,
    shifting/scaling the boxes to match.

    Args:
        mean: fill value(s) for the padded area (broadcast over the canvas).
        prob: probability of applying the transform.
        threshold: lower bound of the shrink ratio.
    """

    def __init__(self, mean, prob, threshold):
        self.prob = prob
        self.mean = mean
        self.threshold = threshold

    def __call__(self, image, target):
        if random.random() < self.prob:
            # image is a CHW float tensor in [0, 1] (tensors use .size(),
            # numpy uses .shape).
            depth = image.size(0)
            height = image.size(1)
            width = image.size(2)
            ratio = random.uniform(self.threshold, 1.0)
            h = int(height * ratio)
            w = int(width * ratio)

            # tensor (CHW, [0,1]) -> uint8 HWC for cv2.resize -> back.
            img = (image.numpy() * 255).astype("uint8")
            img = np.transpose(img, (1, 2, 0))
            img = cv2.resize(img, (int(w), int(h)), interpolation=cv2.INTER_AREA)
            # BUGFIX: removed cv2.cvtColor(img, cv2.COLOR_RGB2BGR) — there
            # was no matching forward conversion, so it permanently swapped
            # the channel order of the output tensor.
            img = np.transpose(img, (2, 0, 1)).astype("float32")
            img /= 255
            image = torch.from_numpy(img)

            # Random top-left corner for pasting the shrunken image.
            left = int(random.uniform(0., float(width - w)))
            top = int(random.uniform(0., float(height - h)))
            # Mean-filled canvas of the original size receives the image.
            canvas = torch.zeros((depth, int(height), int(width)), dtype=torch.float32)
            canvas[:, :, :] = self.mean
            canvas[:, int(top): int(top + h), int(left): int(left + w)] = image
            image = canvas

            # Scale each box by the same ratio and shift by the paste offset.
            boxes = target["boxes"].numpy().copy()
            box_w = (boxes[:, 2] - boxes[:, 0]) * ratio
            box_h = (boxes[:, 3] - boxes[:, 1]) * ratio
            boxes[:, 0] = left + boxes[:, 0] * ratio
            boxes[:, 1] = top + boxes[:, 1] * ratio
            boxes[:, 2] = boxes[:, 0] + box_w
            boxes[:, 3] = boxes[:, 1] + box_h
            # Back to tensor form (boxes must be float).
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        return image, target


# Random occlusion (cutout-style rectangular masking)
class RandomCutout(object):
    """Paint ``n_holes`` fixed-color rectangles of size ``length_h`` x
    ``length_w`` at random positions on the image (cutout-style occlusion).

    Args:
        n_holes: number of rectangles to paint.
        length_h: rectangle height in pixels (around 100 is typical).
        length_w: rectangle width in pixels.
        prob: probability of applying the transform.
    """

    # Normalized RGB fill color for the occluded patches.
    _FILL = (0.6902, 0.5647, 0.4157)

    def __init__(self, n_holes, length_h, length_w, prob):
        self.n_holes = n_holes
        self.length_h = length_h
        self.length_w = length_w
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            h = image.size(1)
            w = image.size(2)

            for _ in range(self.n_holes):
                # Random hole center; the rectangle is clipped at borders.
                y = np.random.randint(h)
                x = np.random.randint(w)
                y1 = np.clip(y - self.length_h // 2, 0, h)
                y2 = np.clip(y + self.length_h // 2, 0, h)
                x1 = np.clip(x - self.length_w // 2, 0, w)
                x2 = np.clip(x + self.length_w // 2, 0, w)
                # BUGFIX: the fill was previously applied once OUTSIDE this
                # loop, so only the last hole actually modified the image
                # (and n_holes == 0 raised NameError).  Paint each hole here.
                for channel, value in enumerate(self._FILL):
                    image[channel, y1: y2, x1: x2] = value
        return image, target


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean ``(mean[1],...,mean[n])`` and std ``(std[1],..,std[n])`` for
    ``n`` channels, each channel of the input ``torch.*Tensor`` becomes
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``.

    .. note::
        Acts out of place unless ``inplace`` is set; the input tensor is
        not mutated by default.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor, target):
        """Normalize ``tensor`` channel-wise; ``target`` passes through."""
        normalized = F.normalize(tensor, self.mean, self.std, self.inplace)
        return normalized, target


class ToTensor(object):
    """Convert the input image (PIL/ndarray) to a float tensor via
    ``F.to_tensor``; the target dict passes through unchanged."""

    def __call__(self, image, target):
        return F.to_tensor(image), target